{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jupyter Utilities Demo\n",
    "\n",
    "MONAI comes with an expanding set of useful utilities for Jupyter notebooks. This demo notebook illustrates their use with a toy segmentation problem. One challenge of Jupyter notebooks is the connection between the browser window and a running cell. Later versions of Jupyter have improved this behaviour, but output from running cells is still lost if the browser window is closed. Furthermore, while a cell is running no other cells can run, so running code to inspect the training process mid-run isn't possible.\n",
    "\n",
    "MONAI's solution is to provide a convenient way of wrapping Ignite `Engine`-derived classes in a thread so that training can happen asynchronously from the notebook's main thread. This allows cells to complete immediately while training continues in the background. Output to stdout/stderr is still (mostly) captured by that cell or whichever cell is subsequently run, and the browser window can be closed or other cells run in the meantime. The wrapper, `ThreadContainer`, provides methods for inspecting the status of the run and plotting some simple outputs. This is particularly useful if you cannot or simply don't want to use external tools like TensorBoard to monitor your progress."
   ]
  },
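  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the idea described above, and not MONAI's actual implementation, the following sketch uses only Python's standard `threading` module. The `slow_task` function and the `results` list are hypothetical stand-ins for a training loop and its logged values. The point is only that `start` returns immediately while the work continues in the background:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import time\n",
    "\n",
    "results = []  # hypothetical stand-in for values logged during training\n",
    "\n",
    "\n",
    "def slow_task(n_steps=5, delay=0.01):\n",
    "    # Hypothetical stand-in for a training loop: each step takes some time\n",
    "    for step in range(n_steps):\n",
    "        time.sleep(delay)\n",
    "        results.append(step)\n",
    "\n",
    "\n",
    "worker = threading.Thread(target=slow_task, daemon=True)\n",
    "worker.start()  # returns immediately; slow_task keeps running in the background\n",
    "print(\"worker alive right after start:\", worker.is_alive())\n",
    "\n",
    "worker.join()  # block until the background work has finished\n",
    "print(\"steps completed:\", len(results))"
   ]
  },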
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0.6.0+1.g8365443a\n",
      "Numpy version: 1.20.3\n",
      "Pytorch version: 1.9.0a0+c3d40fd\n",
      "MONAI flags: HAS_EXT = True, USE_COMPILED = False\n",
      "MONAI rev id: 8365443ababac313340467e5987c7babe2b5b86a\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.5\n",
      "Nibabel version: 3.2.1\n",
      "scikit-image version: 0.15.0\n",
      "Pillow version: 8.2.0\n",
      "Tensorboard version: 2.5.0\n",
      "gdown version: 3.13.0\n",
      "TorchVision version: 0.10.0a0\n",
      "ITK version: 5.1.2\n",
      "tqdm version: 4.53.0\n",
      "lmdb version: 1.2.1\n",
      "psutil version: 5.8.0\n",
      "pandas version: 1.1.4\n",
      "einops version: 0.3.0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import monai\n",
    "from monai.data import Dataset, DataLoader, create_test_image_2d\n",
    "from monai.losses import DiceLoss\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import AddChanneld, Compose, EnsureTyped, AsDiscreted\n",
    "from monai.utils import ThreadContainer\n",
    "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n",
    "from monai.utils.enums import CommonKeys\n",
    "from monai.handlers import MeanDice, ValidationHandler, MetricLogger, from_engine\n",
    "\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create toy dataset\n",
    "\n",
    "We'll create a set of image/segmentation pairs with some noise added since the data doesn't really matter for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "keys = (CommonKeys.IMAGE, CommonKeys.LABEL)\n",
    "data = []\n",
    "\n",
    "# Generate 300 synthetic image/segmentation pairs, seeding each one for reproducibility\n",
    "for i in range(300):\n",
    "    rs = np.random.RandomState(i)\n",
    "    im, seg = create_test_image_2d(256, 256, num_seg_classes=1, noise_max=0.75, random_state=rs)\n",
    "    data.append({keys[0]: im, keys[1]: seg})\n",
    "\n",
    "# Add a channel dimension and ensure the arrays have a type the engines can consume\n",
    "trans = Compose([AddChanneld(keys=keys), EnsureTyped(keys=keys)])\n",
    "\n",
    "# Use the first 240 pairs for training and the remaining 60 for validation\n",
    "train_ds = Dataset(data[:240], trans)\n",
    "val_ds = Dataset(data[240:], trans)\n",
    "train_loader = DataLoader(train_ds, batch_size=10, shuffle=True, num_workers=0)\n",
    "# Shuffling is not needed for validation\n",
    "val_loader = DataLoader(val_ds, batch_size=10, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create network, loss, optimizer\n",
    "\n",
    "We'll choose a sub-optimal network configuration here so that training isn't instantaneous:"
   ]
  },
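  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create UNet, DiceLoss and Adam optimizer\n",
    "# Use the first GPU when available, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",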
    "net = UNet(\n",
    "    spatial_dims=2,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(8, 16, 32),\n",
    "    strides=(2, 2)\n",
    ").to(device)\n",
    "\n",
    "lossfn = DiceLoss(sigmoid=True)\n",
    "lr = 1e-3\n",
    "opt = torch.optim.Adam(net.parameters(), lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create trainer and start running in a separate thread\n",
    "\n",
    "Everything here is standard MONAI code for a simple supervised training run, with an evaluator computing a mean Dice metric:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_epochs = 20\n",
    "\n",
    "val_post_transforms = Compose(\n",
    "    [\n",
    "        EnsureTyped(keys=\"pred\"),\n",
    "        AsDiscreted(keys=\"pred\", threshold_values=True, logit_thresh=0.0),\n",
    "    ]\n",
    ")\n",
    "\n",
    "evaluator = SupervisedEvaluator(\n",
    "    device=device,\n",
    "    val_data_loader=val_loader,\n",
    "    network=net,\n",
    "    postprocessing=val_post_transforms,\n",
    "    key_val_metric={\n",
    "        \"val_mean_dice\": MeanDice(\n",
    "            include_background=True,\n",
    "            output_transform=from_engine([CommonKeys.PRED, CommonKeys.LABEL]),\n",
    "        )\n",
    "    },\n",
    ")\n",
    "\n",
    "logger = MetricLogger(evaluator=evaluator)\n",
    "\n",
    "train_handlers = [logger, ValidationHandler(1, evaluator)]\n",
    "\n",
    "trainer = SupervisedTrainer(\n",
    "    device=device,\n",
    "    max_epochs=max_epochs,\n",
    "    network=net,\n",
    "    optimizer=opt,\n",
    "    loss_function=lossfn,\n",
    "    train_data_loader=train_loader,\n",
    "    train_handlers=train_handlers,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The trainer hasn't been started yet, however, so the cell above completes immediately. Next we wrap the trainer in a `ThreadContainer` object, which inherits from `Thread`. When this thread is started, the provided engine (`trainer` in this case) is run within its control flow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "con = ThreadContainer(trainer)\n",
    "con.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This cell also completes immediately, with the engine continuing in the background. With this running, the browser window can be closed safely after saving the notebook and the run will continue. This may not work as well on platforms like Colab, which might time out notebook sessions, but standard Jupyter servers running on remote machines or containers are compatible.\n",
    "\n",
    "The status of the engine can be queried as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Stopped, Iters: 0/24, Epochs: 20/20, Loss: 0.05168'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "con.status()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This method and the others provided by `ThreadContainer` are thread-safe in that they acquire a lock which prevents the engine thread from continuing until they complete. This ensures the information they return is synchronized with the actual status of the engine.\n",
    "\n",
    "At any time during the run, a simple plot of the loss and metrics over time can be drawn, using the `MetricLogger` object as the source of information:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "con.plot_status(logger)"
   ]
  },
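  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The locking behaviour described earlier for `status` can be sketched with the standard library alone. This is a simplified, assumed model rather than MONAI's code: `ToyRunner`, its attributes, and its `status` and `stop` methods are hypothetical names chosen to mirror the calls used in this notebook. The worker takes the lock for each step, so a reader holding the lock sees a consistent state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import time\n",
    "\n",
    "\n",
    "class ToyRunner(threading.Thread):\n",
    "    # Hypothetical stand-in for an engine wrapper; this is not MONAI's ThreadContainer\n",
    "\n",
    "    def __init__(self, max_iters=50, delay=0.001):\n",
    "        super().__init__(daemon=True)\n",
    "        self.lock = threading.RLock()\n",
    "        self.max_iters = max_iters\n",
    "        self.delay = delay\n",
    "        self.iteration = 0\n",
    "        self.should_stop = False\n",
    "\n",
    "    def run(self):\n",
    "        while True:\n",
    "            with self.lock:  # the worker holds the lock while it updates its state\n",
    "                if self.should_stop or self.iteration >= self.max_iters:\n",
    "                    break\n",
    "                self.iteration += 1\n",
    "            time.sleep(self.delay)  # lock is released here, so readers can get in\n",
    "\n",
    "    def status(self):\n",
    "        with self.lock:  # the worker cannot start another step while we read\n",
    "            state = \"Running\" if self.is_alive() else \"Stopped\"\n",
    "            return f\"{state}, Iters: {self.iteration}/{self.max_iters}\"\n",
    "\n",
    "    def stop(self):\n",
    "        with self.lock:  # request termination, then wait for the thread to exit\n",
    "            self.should_stop = True\n",
    "        self.join()\n",
    "\n",
    "\n",
    "runner = ToyRunner()\n",
    "runner.start()\n",
    "print(runner.status())  # queried while the worker may still be running\n",
    "runner.stop()\n",
    "print(runner.status())  # queried after the worker thread has been joined"
   ]
  },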
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The run can be stopped at any time with `stop` which calls `terminate` on the wrapped engine and then joins with the thread:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "con.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
